# --- Packages ---
library(tidyverse)
library(readxl)
library(randomForest)
library(DirichletReg)
library(tidytext)
library(jiebaR)
library(writexl)
library(caret)
library(randomForest)
library(doParallel)
library(MLmetrics)
library(raster)
library(DirichletReg)
library(MASS)
library(grid)
library(gridExtra)
library(qqplotr)
library(glmmTMB)
library(patchwork)

# --- Load data  ---
# Local grid-center (GC) articles data.
# NOTE(review): the .RData file is expected to define `grid_centers` with at
# least `id` and `text` columns (both are used below) — confirm.
load(file = "data/grid_centers.RData")

# Stop words (UTF-8 encoded), one entry per line
stopwords <- readLines("data/stopwords.txt", encoding = "UTF-8")

# --- Machine Learning ---

# Step 1: Tokenize GC articles (Chinese segmentation)
# Build the jiebaR segmenter once and reuse it; the previous code constructed
# a fresh worker() (and re-loaded its dictionaries) for every single article.
seg_worker <- worker()

# Segment article `i` of grid_centers into a long (id, word) data frame.
# Returns NULL when segmentation yields no tokens, so the article is skipped.
segment_article <- function(i) {
  # Replace digits and Latin letters with spaces to focus on Chinese tokens
  text_cleaned <- gsub("[0-9a-zA-Z]", " ", grid_centers$text[i])
  tokens <- segment(text_cleaned, seg_worker)
  if (length(tokens) == 0) {
    return(NULL)
  }
  data.frame(id = grid_centers$id[i], word = tokens)
}

# Bind all per-article pieces in one call instead of growing a data frame
# with rbind() inside the loop (quadratic copying). The empty prototype keeps
# the same columns/types when no article yields any token.
df_segmented <- do.call(
  rbind,
  c(
    list(data.frame(id = integer(), word = character())),
    lapply(seq_len(nrow(grid_centers)), segment_article)
  )
)

# --- Tidy + TF-IDF features ---
tidy_df <- df_segmented |>
  # Remove stopwords. Bug fix: this previously tested `text`, which is not a
  # column of df_segmented, so it resolved to the last article's text left
  # over from the tokenizing loop (a single value) and kept all rows or none
  # instead of removing stop words. The token column is `word`.
  filter(!(word %in% stopwords)) |>
  # Remove empty tokens
  filter(word != '') |>
  # Count token frequency per document
  count(id, word) |>
  # Drop infrequent tokens (n ≤ 2)
  filter(n > 2) |>
  # Remove single-character tokens
  filter(nchar(word) != 1) |>
  # Compute TF–IDF features
  bind_tf_idf(term = 'word',
              document = 'id',
              n = 'n') |>
  filter(!is.na(tf_idf))

# Wide document–term table of TF–IDF weights: one row per article `id`, one
# column per retained token, 0 where the token does not occur in the article.
# (After count() each id/word pair is already unique, so `sum` only guards
# against duplicates.)
df_matrix_result <- pivot_wider(
  data = tidy_df,
  id_cols = "id",
  names_from = "word",
  values_from = "tf_idf",
  values_fill = 0,
  values_fn = sum
)

# --- Save outputs for downstream models ---
# Cache the token table and the TF–IDF matrix so later stages can start here.
save(df_segmented, file = "generated_data/cleaned_df.RData")
save(df_matrix_result, file = "generated_data/df_matrix_result.RData")

# To resume from the cached outputs instead of recomputing (optional):
# load("generated_data/cleaned_df.RData")
# load("generated_data/df_matrix_result.RData")

# Keep only articles that still have a row in the TF–IDF matrix
final_articles <- grid_centers[grid_centers$id %in% df_matrix_result$id, ]

# Character count per article (proxy for length).
# nchar() is vectorised. The previous sapply() over a character vector
# returned a vector named by the full article texts (USE.NAMES = TRUE), which
# needlessly duplicated every text and could be carried along with the column.
final_articles$words <- nchar(final_articles$text)

save(final_articles, file = "generated_data/final_GC_articles.RData")

# Distribution of articles by official GC account (for Appendix B).
# ungroup() so the stored/exported table is a plain tibble rather than one
# that silently stays grouped by account.
final_GC_by_account <- final_articles |>
  group_by(account) |>
  count() |>
  ungroup()
print(final_GC_by_account, n = 100)

# Appendix B excerpt: export account-level counts
write_xlsx(final_GC_by_account, "generated_data/final_GC_by_account.xlsx")

# Create manual-coding sample: randomly select 30% of articles
# (labels: e.g., S=Service, C=Control, O=Others)
set.seed(30605)
final_articles$topic <- NA
# NOTE: the data-frame name `sample` is kept for compatibility; calls to
# sample() still resolve to base::sample() because R skips non-function
# bindings when looking up a function call.
sample <- final_articles[sample(nrow(final_articles), ceiling(0.30 * nrow(final_articles))), ]
write_xlsx(sample, "generated_data/topic_coding.xlsx")

# Export sampled article texts as .txt files for manual annotators.
# Session-wide default connection encoding, kept for any later code that may
# rely on it; the writes below set their encoding explicitly.
options(encoding = "UTF-8")
coding_dir <- "generated_data/for_coding"
# Writing fails if the target folder does not exist yet.
dir.create(coding_dir, recursive = TRUE, showWarnings = FALSE)

# Write `text` to `path` as UTF-8, always closing the connection.
# Bug fix: the previous cat(..., fileEncoding = "UTF-8") call has no such
# argument, so "UTF-8" was swallowed by `...` and written as an extra final
# line of every annotator file.
write_utf8_text <- function(text, path) {
  con <- file(path, open = "w", encoding = "UTF-8")
  on.exit(close(con), add = TRUE)
  writeLines(text, con)
}

for (i in seq_len(nrow(sample))) {
  write_utf8_text(
    text = as.character(sample$text[[i]]),
    path = file.path(coding_dir, paste0(sample$id[[i]], ".txt"))
  )
}

# Load coded sample (human labels)
coding_result <- read_excel("data/coded.xlsx")

# Attach each article's manual label to the TF–IDF matrix by article id;
# articles that were not hand-coded get NA.
match_indices <- match(df_matrix_result$id, coding_result$id)
df_matrix_result$topic <- coding_result$topic[match_indices]

# Labeled rows train/evaluate the classifier; unlabeled rows are predicted.
has_label <- !is.na(df_matrix_result$topic)
coded_sample <- df_matrix_result[has_label, ]
predict_sample <- df_matrix_result[!has_label, ]

# --- Random Forest classification ---
set.seed(30605)
# 10-fold CV
train_control <- trainControl(method = "cv", number = 10)

# Parallel training: leave one core free, but always request at least one
# worker. detectCores() can return 1, or NA when the count is unknown, which
# made the old `detectCores() - 1` yield 0 or NA.
no_cores <- max(1L, detectCores() - 1L, na.rm = TRUE)
registerDoParallel(no_cores)

# Baseline RF (all features except id).
# NOTE(review): set.seed() fixes caret's CV folds, but forests fitted on
# parallel workers draw from the workers' RNG streams; caret documents
# trainControl(seeds = ...) for fully reproducible parallel runs. Not changed
# here so existing results are not altered.
set.seed(30605)
rf_model1 <- train(topic ~ .-id,
                   data = coded_sample,
                   method = "rf",
                   trControl = train_control)

stopImplicitCluster()

print(rf_model1)
save(rf_model1, file = "generated_data/rf_model1.RData")
# load("generated_data/rf_model1.RData")

# Variable importance (overall) of the baseline model, sorted descending
importance <- varImp(rf_model1, scale = FALSE)
print(importance)

importance_df <- data.frame(Feature = rownames(importance$importance),
                            Importance = importance$importance$Overall)
importance_df <- importance_df[order(-importance_df$Importance), ]
print(importance_df)

# Top-50 features. head() instead of [1:50]: with fewer than 50 features the
# index would run past the end and insert NA names into the formula below.
n_top_features <- 50
important_features <- head(importance_df$Feature, n_top_features)
print(important_features)

# Switch locale for plotting Chinese labels if needed
Sys.setlocale("LC_ALL", "zh_CN.UTF-8")

# RF with top-50 features only (feature-restricted model).
# Each token is wrapped in backticks so names that are not syntactic in the
# active locale still parse; backticks already present in the importance row
# names are stripped first so they are not doubled.
feature_terms <- paste0("`", gsub("`", "", important_features, fixed = TRUE), "`")
important_formula <- reformulate(feature_terms, response = "topic")

# Hyperparameter tuning for mtry in [11, 50]
tune_grid <- expand.grid(.mtry = seq(11, 50, by = 1))

# Parallel training: leave one core free, never fewer than one worker
# (detectCores() may return 1, or NA when the core count is unknown).
no_cores <- max(1L, detectCores() - 1L, na.rm = TRUE)
registerDoParallel(no_cores)

set.seed(30605)
rf_model2 <- train(important_formula,
                   data = coded_sample,
                   method = "rf",
                   trControl = train_control,
                   tuneGrid = tune_grid)

stopImplicitCluster()

print(rf_model2)
save(rf_model2, file = "generated_data/rf_model2.RData")
# Restore locale.
# NOTE(review): this hard-codes en_US.UTF-8 rather than returning to whatever
# locale was active before the zh_CN.UTF-8 switch; confirm that is what every
# machine running this script should end up with.
Sys.setlocale("LC_ALL", "en_US.UTF-8")

# load("generated_data/rf_model2.RData")

# In-sample (resubstitution) performance on the hand-coded articles:
# confusion matrix & F1 by class. rf_model2 was fitted on these same rows, so
# these numbers are optimistic; the cross-validated estimates are the
# resampling results reported by print(rf_model2).
predictions <- predict(rf_model2, newdata = coded_sample)
all_levels <- union(levels(predictions), levels(coded_sample$topic))
predictions <- factor(predictions, levels = all_levels)
coded_sample$topic <- factor(coded_sample$topic, levels = all_levels)
confusion_matrix <- confusionMatrix(predictions, coded_sample$topic)
print(confusion_matrix)

# Per-class F1 scores (S=Service; C=Control; O=Others), one-vs-rest
f1_classes <- c("S", "C", "O")
f1_scores <- vapply(f1_classes, function(cls) {
  F1_Score(y_pred = predictions, y_true = coded_sample$topic, positive = cls)
}, numeric(1))
f1_S <- f1_scores[["S"]]
f1_C <- f1_scores[["C"]]
f1_O <- f1_scores[["O"]]
for (cls in f1_classes) {
  print(paste0("F1 Score for ", cls, ": ", f1_scores[[cls]]))
}

# --- Figure A6: Top important words (percentage of total importance) ---
# Importance of each feature in the tuned top-50 model (rf_model2), expressed
# as a percentage of the summed importance; the 20 largest are kept.
# Note: this reassigns `importance_df` (previously rf_model1's table).
importance_words <- varImp(rf_model2, scale = FALSE)
importance_df <- importance_words $importance
importance_df$Variables <- rownames(importance_df)
rownames(importance_df) <- NULL
total_importance <- sum(importance_df$Overall)
importance_df$Percentage <- (importance_df$Overall / total_importance) * 100
top_importance_df <- head(importance_df[order(-importance_df$Percentage), ], 20)

top_importance_df <- top_importance_df |>
  rename(importance = Percentage,
         Chinese_word = Variables) |>
  # Manual bilingual labels for the 20 most important tokens.
  # NOTE(review): these glosses are matched to rows purely by position (the
  # descending-importance order above). If the model, data or tie order
  # changes, the English labels will silently attach to the wrong Chinese
  # words — re-check against `Chinese_word` whenever rf_model2 is refit.
  mutate(English_word = c("Prevention and Control", "Pandemic", "Work", "Grid System", "Personnel",
                          "Center",  "Grid-Based", "Propaganda", "Residents", "Community", "Elderly",
                          "The Masses", "We", "The Public", "Pandemic Prevention",
                          "They", "Masks", "Management", "Governance", "Villagers"))

# rank 1 = most important word; ties broken by current row order
top_importance_df <- top_importance_df |>
  mutate(rank = rank(-importance, ties.method = "first"))

# FIGURE A6. Top 20 Important Words in RF Model (Total Importance: 100%)
# One point per word; after coord_flip() words run down the vertical axis
# with rank 1 at the top. The Chinese token and its English gloss are drawn as
# text beside each point, so the y-axis tick labels are hidden.
figureA6 <- ggplot(top_importance_df, aes(x = reorder(Chinese_word, -rank), y = importance)) +
  geom_point() +
  geom_text(aes(label = Chinese_word), hjust = 0, size = 5, nudge_x = 0,  nudge_y = 0.4) +
  geom_text(aes(label = English_word), hjust = 0, size = 6, nudge_x = 0.04, nudge_y = 1.42) +
  # NOTE(review): fixed 0-22% range; a word above 22% would be dropped.
  scale_y_continuous(name = "Word Importance",
                     labels = scales::percent_format(scale = 1),
                     breaks = seq(0, 22, by = 5),
                     limits = c(0, 22)) +
  coord_flip() +
  labs(title = "", x = "Words") +
  # theme_bw() is a complete theme; the theme_light(base_size = 16) that used
  # to precede it was entirely overridden and has been removed.
  theme_bw() +
  theme(axis.title.x = element_text(size = 18),
        axis.title.y = element_text(size = 18),
        plot.title = element_text(hjust = 0.5, face = "bold",size = 18),
        axis.text.x = element_text(size = 18),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank(),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        panel.background = element_rect(fill = "white", colour = NA),
        # `linewidth` replaces the `size` argument deprecated for borders in
        # ggplot2 3.4 (the file already uses `linewidth` elsewhere).
        panel.border = element_rect(colour = "black", fill = NA, linewidth = 1))

figureA6
ggsave(filename = "plots/FIGUREA6.png", plot = figureA6, width = 14.4, height = 7.2, units = "in", dpi = 200)

# --- Predict the remaining (unlabeled) articles with class probabilities ---
# predict(type = "prob") yields one probability column per class (C, O, S);
# the table is stored as `topic` and each class is copied to its own column.
class_probs <- predict(rf_model2, newdata = predict_sample , type="prob")
predict_sample$topic <- class_probs
for (cls in c("C", "O", "S")) {
  predict_sample[[cls]] <- class_probs[[cls]]
}
predict_result <- dplyr::select(predict_sample, id, C, O, S)

# Gold labels become one-hot rows (1 for the coded class, 0 otherwise) so they
# can be stacked with the predicted probabilities.
coded_result <- dplyr::select(coded_sample, id, topic)
for (cls in c("C", "O", "S")) {
  coded_result[[cls]] <- ifelse(coded_result$topic == cls, 1, 0)
}
coded_result <- dplyr::select(coded_result, id, C, O, S)

# Combine coded and predicted samples (used as dependent variables later).
# merge(all = TRUE) joins on every shared column (id, C, O, S); the coded and
# unlabeled id sets are disjoint, so this stacks the rows (sorted by key).
classified_samples <- merge(coded_result, predict_result, all = TRUE)
save(classified_samples, file = "generated_data/classified_samples.RData")

# --- Describe dependent variables over time (Figure 2) ---
# Daily mean share of each function (C, S, O) across articles, reshaped long
# with one row per date x function.
DV_description <- classified_samples %>%
  left_join(final_articles, by = "id") %>%
  group_by(date) %>%
  summarise(across(c(C, S, O), ~ mean(.x, na.rm = TRUE), .names = "daily_{.col}")) %>%
  pivot_longer(
    cols = starts_with("daily_"),
    names_to = "functions",
    values_to = "average",
    names_prefix = "daily_"
  )

# FIGURE 2. Time-series of functional proportions with Wuhan lockdown marker.
# Daily mean shares are smoothed with loess (no CI band); the three functions
# are all drawn in black and distinguished by line type only.
# `linewidth` replaces the `size` argument deprecated for lines/borders in
# ggplot2 3.4 (the file already uses `linewidth` elsewhere).
figure2 <- ggplot(DV_description, aes(x = date, y = average, color = functions, linetype = functions)) +
  geom_smooth(se = FALSE, method = "loess", linewidth = 2) +
  labs(title = "", x = "Date", y = "Proportion of Functions") +
  theme_bw() +
  theme(plot.title = element_text(hjust = 0.5, size = 20, face = "bold"),
        axis.title.x = element_blank(),
        axis.title.y = element_text(size = 16),
        axis.text.x = element_text(size = 16),
        axis.text.y = element_text(size = 16),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        panel.background = element_rect(fill = "white", colour = NA),
        legend.position = c(0.1, 0.83),
        legend.title = element_blank(),
        legend.text = element_text(size = 20, lineheight = 2),
        legend.key.size = unit(1.5, "cm"),
        legend.key.width = unit(2, "cm"),
        legend.spacing.x = unit(1, "cm"),
        legend.spacing.y = unit(0.5, "cm"),
        panel.border = element_rect(colour = "black", fill = NA, linewidth = 1)
  ) +
  scale_y_continuous(breaks = c(0.2, 0.4, 0.6, 0.8), limits = c(0, 0.9)) +
  scale_color_manual(values = c("C" = "black", "S" = "black", "O" = "black"),
                     labels = c("C" = "Control", "S" = "Service", "O" = "Others")) +
  # "18" is a hex dash pattern (1 unit on, 8 units off)
  scale_linetype_manual(values = c("C" = "18", "S" = "dashed", "O" = "solid"),
                        labels = c("C" = "Control", "S" = "Service", "O" = "Others")) +
  geom_vline(xintercept = as.Date("2020-01-23"), linetype = "dashed", color = "#0033FF", linewidth = 1.5) +
  annotate("text", x = as.Date("2020-01-08"), y = 0.78, label = "Wuhan Lockdown", size = 6, color = "black", hjust = 1) +
  annotate("segment", x = as.Date("2020-01-10"), xend = as.Date("2020-01-21"), y = 0.78, yend = 0.78,
           colour = "black", linewidth = 1, arrow = arrow(length = unit(0.2, "cm")))

figure2
ggsave(filename = "plots/FIGURE2.png", plot = figure2, width = 12.4, height = 7.2, units = "in", dpi = 800)

# --- Independent variables (accounts & geo features) ---
account_data <- read_xlsx('data/account_data.xlsx')
summary(account_data)

# Proxy for local development/intensity: nightlight means within
# account-specific bounding boxes.
# Each account gets a square box centred on (longitude, latitude) whose
# half-width (same units as the coordinates) grows with admin `area` in
# heuristic tiers; when no tier matches (e.g. missing `area`) the half-width
# is 0 and the box collapses to the point itself.
account_data <- account_data %>%
  mutate(
    .box_half_width = case_when(
      area < 10 ~ 0.01,
      area <= 100 ~ 0.03,
      area <= 1000 ~ 0.1,
      area > 1000 ~ 0.2,
      TRUE ~ 0
    ),
    lon_west = longitude - .box_half_width,
    lon_east = longitude + .box_half_width,
    lat_south = latitude - .box_half_width,
    lat_north = latitude + .box_half_width
  ) %>%
  dplyr::select(-.box_half_width)

# Load the nightlight raster (2020 composite).
# NOTE(review): this was previously described as VIIRS, but the file name
# follows DMSP-OLS "F16 ... stable_lights.avg_vis" naming — confirm the
# product before citing it.
hiResLights <- raster(
  'data/nightlight2020/F16_20200101_20201231.global.stable_lights.avg_vis.tif')

# Bounding-box corners selected by name. The previous positional
# account_data[, 9:12] was only correct if the spreadsheet had exactly eight
# columns before the four box columns were appended.
coordinates <- account_data[, c("lon_west", "lon_east", "lat_south", "lat_north")]

# Mean nightlight per bounding box: crop the raster to each account's box
# (extent order: xmin, xmax, ymin, ymax) and average the cell values. As
# before, NA cells are not removed. Result is a numeric vector, one value per
# account (formerly a list grown inside the loop).
nightlight_means <- vapply(seq_len(nrow(coordinates)), function(i) {
  box <- extent(coordinates$lon_west[i], coordinates$lon_east[i],
                coordinates$lat_south[i], coordinates$lat_north[i])
  mean(values(crop(hiResLights, box)))
}, numeric(1))

# Transformations for controls: attach the per-account nightlight means
# (coerced to a plain numeric column) and log-transform GDP per capita and
# population.
account_data <- account_data %>%
  mutate(
    nightlight2020 = as.numeric(nightlight_means),
    loggdppc = log(gdp_pc),
    logpop = log(population)
  )

summary(classified_samples)
summary(final_articles)

# --- Data prep for Dirichlet regression ---
# Article-level identifiers and metadata
local_grid <- final_articles[, c("account", "id", "interval", "words", "GC")]

# Attach the (hand-coded or predicted) function composition by article id
local_grid <- merge(local_grid, classified_samples, by = "id")

# Join account-level covariates
account_covariates <- c("account", "city_level", "gdp_pc", "population",
                        "hubei", "nightlight2020", "loggdppc", "logpop")
local_grid <- left_join(local_grid, account_data[, account_covariates],
                        by = "account")

# Keep analysis columns (composition + covariates + IDs)
analysis_columns <- c("O", "C", "S", "interval", "city_level", "nightlight2020",
                      "loggdppc", "logpop", "hubei", "words", "account", "GC", "id")
local_grid <- local_grid[, analysis_columns]

# Controls: log(words) as text-length proxy
local_grid$logwords <- log(local_grid$words)

# Post-outbreak indicator: 1 when interval > 0, 0 when interval <= 0, NA when
# interval is missing (if_else() returns NA for a missing condition). The
# previous nested if_else() mixed a double (0) with NA_integer_, which dplyr
# versions before 1.1 reject as a type mismatch.
local_grid <- local_grid |>
  mutate(after_pandemic = if_else(interval > 0, 1, 0))

# Summary checks: descriptive statistics for the analysis sample.
# Composition shares, interval, logwords and after_pandemic are article-level
# (local_grid); nightlight, GDP, population and hubei are summarised over
# accounts (account_data). No na.rm, so any NA makes the result NA.
summary(local_grid)
sd(local_grid$O); sd(local_grid$C); sd(local_grid$S)
sd(local_grid$interval); sd(local_grid$logwords)

summary(account_data)
sd(account_data$nightlight2020); sd(account_data$loggdppc); sd(account_data$logpop)
sd(account_data$hubei); sd(local_grid$after_pandemic)
median(local_grid$after_pandemic)

# Composition matrix for DirichletReg.
# Columns are bound in the order O, C, S (Others, Control, Service). They are
# unnamed here, so DR_data presumably names them v1, v2, v3 — the later
# simulation code reads coef(DRmodel3)$v1..v3 and the plots label V1..V3 in
# this same order; keep it consistent if the order ever changes.
# NOTE(review): hand-coded rows are exact 0/1 one-hot vectors; DR_data is
# expected to transform compositions containing zeros — check its warning.
local_grid$topic <- cbind(local_grid$O, local_grid$C, local_grid$S)
# `topic_drd` lives outside `local_grid`; the DirichReg() formulas find it in
# the calling environment, so it must stay row-aligned with local_grid.
topic_drd <- DR_data(local_grid$topic)
local_grid$GC <- as.factor(local_grid$GC)
local_grid$account <- as.factor(local_grid$account)

# Descriptive means/medians. Article-level variables are averaged over
# articles (local_grid); hubei and nightlight/GDP/population over accounts
# (account_data).
mean(local_grid$O); mean(local_grid$C); mean(local_grid$S)
mean(local_grid$interval); mean(local_grid$after_pandemic)
mean(account_data$hubei); mean(account_data$nightlight2020)
mean(account_data$loggdppc); mean(account_data$logpop); mean(local_grid$logwords)
median(local_grid$O); median(local_grid$C); median(local_grid$S)

# Cache the analysis-ready article table
save(local_grid, file = "generated_data/local_grid.RData")

# --- Dirichlet regression models ---
# All three models use the composition `topic_drd` (O, C, S) and the same
# controls; they differ only in how time enters.
# NOTE(review): set.seed() only matters if DirichReg's fitting draws random
# numbers — kept for reproducibility.
set.seed(30605)
# Model 1: Baseline with linear time trend
DRmodel1 <- DirichReg(topic_drd ~ interval + nightlight2020 + loggdppc + logpop + hubei + logwords, data = local_grid)

# Model 2: Add post-pandemic indicator (level shift after day 0)
DRmodel2 <- DirichReg(topic_drd ~ interval + nightlight2020 + loggdppc + logpop + hubei + after_pandemic + logwords, data = local_grid)

# Model 3: Add quadratic time term (used for the predicted compositions in Figure 3)
DRmodel3 <- DirichReg(topic_drd ~ interval + I(interval^2)  + nightlight2020 + loggdppc + logpop + hubei + after_pandemic + logwords, data = local_grid)

# Print coefficient tables with 6 significant digits
options(digits=6)
summary(DRmodel1)
summary(DRmodel2)
summary(DRmodel3)

# Model comparison (lower AIC/BIC indicates better fit).
# Each statistic is a numeric vector named by model.
options(digits= 8)
dr_models <- list(DRmodel1 = DRmodel1, DRmodel2 = DRmodel2, DRmodel3 = DRmodel3)

aic_values <- vapply(dr_models, AIC, numeric(1))
aic_values

bic_values <- vapply(dr_models, BIC, numeric(1))
bic_values

log_likelihood_values <- vapply(dr_models, function(m) as.numeric(logLik(m)), numeric(1))
log_likelihood_values

# --- Predicted compositions over time (Dirichlet, outside vs. within Hubei) ---
# Baseline covariates are held at their medians and only `hubei` is contrasted
# (0 = outside Hubei, 1 = within Hubei). 95% CIs come from parametric
# simulation: coefficient vectors are drawn from N(coef, vcov) of DRmodel3 and
# mapped to expected compositions.
#
# Fixes versus the earlier copy-pasted version: the Hubei quantile loop wrote
# into the outside-Hubei array `prob_cis` (leaving `prob_cis2` empty); the
# design matrix was rebuilt on every draw; exp() in softmax could overflow.

# Row-wise softmax: turns per-category linear predictors into proportions.
# The row maximum is subtracted before exp() to avoid overflow; the result is
# mathematically unchanged.
# NOTE(review): this equals the expected composition under DirichReg's
# default "common" parametrization (alpha_j = exp(eta_j)), which is how
# DRmodel3 is fitted (no `model =` argument) — confirm if that changes.
softmax <- function(eta) {
  eta <- eta - apply(eta, 1, max)
  exp_eta <- exp(eta)
  exp_eta / rowSums(exp_eta)
}
# Alias retained for code that referred to the former Hubei-specific copy.
softmax2 <- softmax

# Prediction grid: one row per day across the observed interval range, other
# covariates fixed at their medians, `hubei` set to the given value.
make_prediction_grid <- function(hubei) {
  days <- min(local_grid$interval):max(local_grid$interval)
  data.frame(
    interval = days,
    nightlight2020 = median(account_data$nightlight2020, na.rm = TRUE),
    loggdppc = median(account_data$loggdppc, na.rm = TRUE),
    logpop = median(account_data$logpop, na.rm = TRUE),
    hubei = hubei,
    logwords = median(local_grid$logwords),
    after_pandemic = ifelse(days > 0, 1, 0)
  )
}

# Point predictions (columns V1-V3, in the O/C/S order of the composition),
# `interval`, and simulated CI bounds lower_CI_j / upper_CI_j for j = 1..3.
predict_composition_ci <- function(model, newdata, n_sims = 2000, seed = 30605,
                                   probs = c(0.025, 0.975)) {
  # NOTE(review): `type` is not, as far as visible here, an argument of
  # DirichletReg's predict method and is presumably absorbed by `...`; kept
  # unchanged so the call matches the original.
  pred <- as.data.frame(predict(model, newdata = newdata, type = "probs"))
  pred$interval <- newdata$interval

  coefs <- coef(model)
  set.seed(seed)
  draws <- mvrnorm(n = n_sims, mu = c(coefs$v1, coefs$v2, coefs$v3),
                   Sigma = vcov(model))

  # The design matrix does not depend on the draw, so build it once. Its
  # columns must follow the DRmodel3 formula order.
  design <- model.matrix(~ 1 + interval + I(interval^2) + nightlight2020 +
                           loggdppc + logpop + hubei + after_pandemic + logwords,
                         newdata)
  n_rows <- nrow(newdata)
  sims <- matrix(NA_real_, nrow = n_sims, ncol = n_rows * 3)
  for (i in seq_len(n_sims)) {
    beta <- matrix(draws[i, ], nrow = 3, byrow = TRUE)
    # Flattened row-wise: column (k - 1) * 3 + j is category j on grid row k.
    sims[i, ] <- as.vector(t(softmax(design %*% t(beta))))
  }

  bounds <- apply(sims, 2, quantile, probs = probs)
  for (j in 1:3) {
    cols <- (seq_len(n_rows) - 1) * 3 + j
    pred[[paste0("lower_CI_", j)]] <- bounds[1, cols]
    pred[[paste0("upper_CI_", j)]] <- bounds[2, cols]
  }
  pred
}

# One FIGURE 3 panel (black-and-white): CI ribbons and mean lines for the
# three functions, drawn separately for days <= 0 (`pre`) and > 0 (`post`) so
# the after_pandemic step shows as a break rather than a connecting line.
plot_composition_panel <- function(pre, post, y_label, panel_label) {
  segments <- list(pre, post)
  fills <- c("grey60", "grey40", "grey20")
  # Layer order matches the original: all ribbons (segment, then category),
  # then all lines (category, then segment).
  ribbon_layers <- do.call(c, lapply(segments, function(seg) {
    lapply(1:3, function(j) {
      geom_ribbon(
        data = seg,
        aes(x = interval,
            ymin = .data[[paste0("lower_CI_", j)]],
            ymax = .data[[paste0("upper_CI_", j)]]),
        fill = fills[j], alpha = 0.2
      )
    })
  }))
  line_layers <- do.call(c, lapply(1:3, function(j) {
    lapply(segments, function(seg) {
      geom_line(
        data = seg,
        aes(x = interval, y = .data[[paste0("V", j)]],
            linetype = paste("Category", j)),
        colour = "black", linewidth = 2
      )
    })
  }))

  ggplot() +
    ribbon_layers +
    line_layers +
    geom_vline(xintercept = 0, color = "#0033FF", linewidth = 1.5, linetype = "dashed") +
    scale_linetype_manual(
      values = c("Category 1" = "solid", "Category 2" = "dotted", "Category 3" = "dashed"),
      labels = c("Others", "Social Control", "Providing Services")
    ) +
    labs(
      title = "",
      x = "Day",
      y = y_label,
      linetype = "Functions (with 95% C.I.)"
    ) +
    theme_bw() +
    theme(
      legend.position = c(0.23,0.85),
      legend.text = element_text(size = 18),
      text = element_text(size = 18),
      plot.title = element_text(size = 18),
      axis.title = element_text(size = 15.5),
      axis.text = element_text(size = 18),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      panel.background = element_blank(),
      legend.key.size = unit(1.2, "cm"),
      legend.key.width = unit(2, "cm"),
      legend.spacing.x = unit(1, "cm"),
      legend.spacing.y = unit(0.4, "cm"),
      # `linewidth` replaces the `size` argument deprecated in ggplot2 3.4
      panel.border = element_rect(colour = "black", fill = NA, linewidth = 1)
    ) +
    scale_x_continuous(
      breaks = c(-50, -25, 0, 25, 50),
      limits = c(-55, 55)
    ) +
    scale_y_continuous(
      breaks = c(0.2, 0.4, 0.6, 0.8),
      limits = c(0, 0.85)
    ) +
    annotate("text", x = 52, y = 0.7, label = "Wuhan Lockdown", hjust = 1, size = 7, color = "black") +
    annotate("segment", x = 14, xend = 1, y = 0.7, yend = 0.7,
             colour = "black", linewidth = 1, arrow = arrow(length = unit(0.2, "cm"))) +
    annotate("text", x = 55, y = 0.85, label = panel_label, hjust = 1, size = 7, color = "black")
}

# Outside Hubei
new_data <- make_prediction_grid(hubei = 0)
predicted_probs_df <- predict_composition_ci(DRmodel3, new_data)
predicted_probs_df_less_than_zero <- predicted_probs_df[predicted_probs_df$interval <= 0, ]
predicted_probs_df_greater_than_zero <- predicted_probs_df[predicted_probs_df$interval > 0, ]

# FIGURE 3 (panel, black-and-white): Outside Hubei
DR_P1 <- plot_composition_panel(
  predicted_probs_df_less_than_zero,
  predicted_probs_df_greater_than_zero,
  y_label = "Functions of the Grid outside of Hubei Province (Proportion)",
  panel_label = "Outside Hubei"
)

print(DR_P1)

# Within Hubei (same baseline covariates, set hubei=1)
new_data2 <- make_prediction_grid(hubei = 1)
predicted_probs_df2 <- predict_composition_ci(DRmodel3, new_data2)
predicted_probs_df_less_than_zero2 <- predicted_probs_df2[predicted_probs_df2$interval <= 0, ]
predicted_probs_df_greater_than_zero2 <- predicted_probs_df2[predicted_probs_df2$interval > 0, ]

# FIGURE 3 (panel): Within Hubei
DR_P2 <- plot_composition_panel(
  predicted_probs_df_less_than_zero2,
  predicted_probs_df_greater_than_zero2,
  y_label = "Functions of the Grid within the Hubei Province (Proportion)",
  panel_label = "Hubei Province"
)

print(DR_P2)

# Combine panels into FIGURE 3
# NOTE(review): `title_grob3` is built but never passed to grid.arrange()
# (e.g. via `top =`), so the saved figure has no title. Its negative y
# positions also appear to place it outside its viewport — confirm whether
# the title is meant to be dropped before wiring it in.
title_text3 <- "FIGURE 3. The Functional Change of the Grid before and after COVID-19 Outbreak Within and Outside the Hubei Province"
title_grob3 <- grobTree(
  textGrob(title_text3,
           gp = gpar(fontsize = 20, fontface = "bold"), y = -0.6),
  linesGrob(gp = gpar(col = "black", lwd = 2), y = unit(-2.4, "npc"),
            x = unit(c(0.05, 0.95), "npc"))
)
# grid.arrange() draws the two panels side by side on the current device and
# returns the arranged grob, which ggsave() then writes to disk.
figure3 <- grid.arrange(
  DR_P1,
  DR_P2,
  ncol = 2
)
ggsave(filename = "plots/FIGURE3.png", plot = figure3, width = 16.8, height = 7.5, units = "in", dpi = 600)

# --- Dirichlet model diagnostics (Appendix H) ---
# Raw residuals and fitted values of the quadratic-time model (DRmodel3).
# NOTE: the variable name `residuals` masks the generic only for non-call
# lookups; calls such as residuals(...) still resolve to the function.
residuals <- residuals(DRmodel3, type = "raw")
fitted_values <- fitted(DRmodel3)

# Scatter + loess for each function vs time.
# Each call below returns a ggplot visibly, so it is displayed like the
# former bare plot expressions when run interactively.

# Diagnostic scatter of one composition column of local_grid against day,
# with a loess smoother. theme_bw() is a complete theme, so the
# theme_minimal() that used to precede it had no effect and is omitted.
plot_function_trend <- function(column, y_label, title) {
  ggplot(local_grid, aes(x = interval, y = .data[[column]])) +
    geom_point(size = 0.8) +
    geom_smooth(method = "loess") +
    xlab("Day") +
    ylab(y_label) +
    ggtitle(title) +
    theme_bw() +
    theme(plot.title = element_text(size=9, lineheight=3),
          axis.text.x = element_text(size = 10),
          axis.title.x = element_text(size =10),
          axis.title.y = element_text(size =10),
          axis.text.y = element_text(size = 10),
          panel.grid.major = element_blank(),
          panel.grid.minor = element_blank(),
          panel.background = element_blank())
}

# Others
plot_function_trend("O", "Others", "Functional Change of the Grid (Other Functions)")

# Social Control (y label previously said "Others" due to copy-paste)
plot_function_trend("C", "Social Control", "Functional Change of the Grid (Social Control)")

# Providing Service (y label previously said "Others" due to copy-paste)
plot_function_trend("S", "Providing Service", "Functional Change of the Grid (Providing Service)")

# Save references to individual ggplot objects (used in patchwork grid)
plot_others_functional <- ggplot(local_grid, aes(x = interval, y = O)) +
  geom_point(size = 0.8) +
  geom_smooth(method = "loess") +
  theme_minimal() +
  xlab("Day") +
  ylab("Others") +
  ggtitle("Functional Change of the Grid (Other Functions)")+
  theme_bw() +
  theme(plot.title = element_text(size=9, lineheight=3),
        axis.text.x = element_text(size = 10),
        axis.title.x = element_text(size =10),
        axis.title.y = element_text(size =10),
        axis.text.y = element_text(size = 10),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        panel.background = element_blank())

plot_social_control_functional <- ggplot(local_grid, aes(x = interval, y = C)) +
  geom_point(size = 0.8) +
  geom_smooth(method = "loess") +
  theme_minimal() +
  xlab("Day") +
  ylab("Social Control") +
  ggtitle("Functional Change of the Grid (Social Control)")+
  theme_bw() +
  theme(plot.title = element_text(size=9, lineheight=3),
        axis.text.x = element_text(size = 10),
        axis.title.x = element_text(size =10),
        axis.title.y = element_text(size =10),
        axis.text.y = element_text(size = 10),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        panel.background = element_blank())

plot_providing_service_functional <- ggplot(local_grid, aes(x = interval, y = S)) +
  geom_point(size = 0.8) +
  geom_smooth(method = "loess") +
  theme_minimal() +
  xlab("Day") +
  ylab("Providing Service") +
  ggtitle("Functional Change of the Grid (Providing Service)")+
  theme_bw() +
  theme(plot.title = element_text(size=9, lineheight=3),
        axis.text.x = element_text(size = 10),
        axis.title.x = element_text(size =10),
        axis.title.y = element_text(size =10),
        axis.text.y = element_text(size = 10),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        panel.background = element_blank())

# Residual plots (raw residuals) for each composition component
# The residual columns of DRmodel3 are copied into the plot data so that all
# aesthetics come from the plot's own data frame. The original aes()
# referenced the global `residuals` object, which was evaluated only at
# render time and was not tied to the plotted rows.
# The residual matrix must be row-aligned with local_grid; fail early rather
# than at render time.
stopifnot(nrow(residuals) == nrow(local_grid))
residual_plot_data <- data.frame(
  interval = local_grid$interval,
  resid_others = residuals[, 1],
  resid_social_control = residuals[, 2],
  resid_providing_service = residuals[, 3]
)

# Template: residuals against day with a dashed zero reference line.
#
# Arguments:
#   data       data frame with `interval` and the column `y_var`
#   y_var      name (string) of the residual column
#   y_label    y-axis label
#   plot_title panel title
build_residual_plot <- function(data, y_var, y_label, plot_title) {
  ggplot(data, aes(x = interval, y = .data[[y_var]])) +
    geom_point(size = 0.8) +
    geom_hline(yintercept = 0, linetype = "dashed", color = "blue", linewidth = 1.5) +
    xlab("Day") +
    ylab(y_label) +
    ggtitle(plot_title) +
    theme_bw() +
    theme(
      plot.title = element_text(size = 9, lineheight = 3),
      axis.text.x = element_text(size = 10),
      axis.title.x = element_text(size = 10),
      axis.title.y = element_text(size = 10),
      axis.text.y = element_text(size = 10),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      panel.background = element_blank()
    )
}

plot_others_residuals <- build_residual_plot(
  residual_plot_data, "resid_others",
  "Residuals of Others", "Residuals (Others)"
)
plot_social_control_residuals <- build_residual_plot(
  residual_plot_data, "resid_social_control",
  "Residuals of Social Control", "Residuals (Social Control)"
)
plot_providing_service_residuals <- build_residual_plot(
  residual_plot_data, "resid_providing_service",
  "Residuals of Providing Service", "Residuals (Providing Service)"
)

# Q-Q plots for residual normality checks (visual diagnostic)
# Template: normal Q-Q panel (qqplotr point + reference line) for one vector
# of residuals.
#
# Arguments:
#   resid_values numeric vector of residuals
#   plot_title   panel title
build_residual_qq_plot <- function(resid_values, plot_title) {
  qq_frame <- data.frame(residuals = resid_values)
  ggplot(qq_frame, aes(sample = residuals)) +
    stat_qq_line() +
    stat_qq_point(size = 0.8) +
    ggtitle(plot_title) +
    theme_bw() +
    theme(
      plot.title = element_text(size = 9, lineheight = 3),
      axis.text.x = element_text(size = 10),
      axis.title.x = element_text(size = 10),
      axis.title.y = element_text(size = 10),
      axis.text.y = element_text(size = 10),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      panel.background = element_blank()
    )
}

plot_others_qq <- build_residual_qq_plot(
  residuals[, 1], "Normal Q-Q Plot of 'Others' Residuals"
)
plot_social_control_qq <- build_residual_qq_plot(
  residuals[, 2], "Normal Q-Q Plot of 'Social Control' Residuals"
)
plot_providing_service_qq <- build_residual_qq_plot(
  residuals[, 3], "Normal Q-Q Plot of 'Providing Service' Residuals"
)

# 3×3 diagnostic grid (functional trends, residuals, Q-Q)
# Panels are filled row by row: functional trends, then raw residuals, then
# Q-Q plots, each in the order Others / Social Control / Providing Service.
diagnostic_panels <- list(
  plot_others_functional, plot_social_control_functional, plot_providing_service_functional,
  plot_others_residuals, plot_social_control_residuals, plot_providing_service_residuals,
  plot_others_qq, plot_social_control_qq, plot_providing_service_qq
)
combined_plots <- wrap_plots(diagnostic_panels, ncol = 1)

Dirichelet_diag <- combined_plots +
  plot_layout(nrow = 3, ncol = 3) +
  plot_annotation(
    title = "",   # FIGURE A6-1. Dirichlet Regression Diagnosis
    theme = theme(plot.title = element_text(hjust = 0.5, face = "bold"))
  )

# NOTE(review): the caption comment above says "FIGURE A6-1" while the file is
# written as FIGUREA7.png — confirm which appendix label is correct.
ggsave(
  filename = "plots/FIGUREA7.png",
  plot = Dirichelet_diag,
  width = 10.5,
  height = 8.4,
  units = "in",
  dpi = 600
)

# --- Robustness checks for Dirichlet model ---
# Each check refits the main Dirichlet specification on a subsample of
# local_grid. Each DirichReg() call is written out literally so its formula
# references its own DR_data object (topic_drd_R1/R2/R3).

# RB1: Exclude five most prolific GC accounts (to avoid dominance by prolific posters)
account_counts <- table(local_grid$account)
sorted_accounts <- names(sort(account_counts, decreasing = TRUE))
# head() instead of [1:5]: with fewer than five accounts, [1:5] would pad the
# vector with NA.
top_accounts_to_remove <- head(sorted_accounts, 5)
data_Robust1 <- local_grid[!(local_grid$account %in% top_accounts_to_remove), ]

data_Robust1$topic <- cbind(data_Robust1$O, data_Robust1$C, data_Robust1$S)
topic_drd_R1 <- DR_data(data_Robust1$topic)
DRmodel_RB1 <- DirichReg(topic_drd_R1 ~ interval + I(interval^2)  + nightlight2020 + loggdppc + logpop + hubei+ logwords + after_pandemic , data = data_Robust1)
summary(DRmodel_RB1)

# RB2: City-level grids only.
# Note: filter() drops rows whose city_level is NA, so such rows appear in
# neither RB2 nor RB3.
data_Robust2 <- local_grid |>
  filter(city_level == "1")
data_Robust2$topic <- cbind(data_Robust2$O, data_Robust2$C, data_Robust2$S)
topic_drd_R2 <- DR_data(data_Robust2$topic)
DRmodel_RB2 <- DirichReg(topic_drd_R2 ~ interval + I(interval^2) + nightlight2020 + loggdppc + logpop + hubei + logwords + after_pandemic, data = data_Robust2)
summary(DRmodel_RB2)

# RB3: County-level and below
data_Robust3 <- local_grid |>
  filter(city_level != "1")
data_Robust3$topic <- cbind(data_Robust3$O, data_Robust3$C, data_Robust3$S)
topic_drd_R3 <- DR_data(data_Robust3$topic)
DRmodel_RB3 <- DirichReg(topic_drd_R3 ~ interval + I(interval^2) + nightlight2020 + loggdppc + logpop + hubei + logwords + after_pandemic, data = data_Robust3)
summary(DRmodel_RB3)

# Compare AIC/BIC across robustness models.
# NOTE(review): RB1-RB3 are fitted on different subsamples, so these
# information criteria are descriptive only, not a model-selection test.
robust_model_list <- list(
  DRmodel_RB1 = DRmodel_RB1,
  DRmodel_RB2 = DRmodel_RB2,
  DRmodel_RB3 = DRmodel_RB3
)

aic_values_RB <- vapply(robust_model_list, AIC, numeric(1))
aic_values_RB

bic_values_RB <- vapply(robust_model_list, BIC, numeric(1))
bic_values_RB

log_likelihood_values <- vapply(
  robust_model_list,
  function(fitted_model) as.numeric(logLik(fitted_model)),
  numeric(1)
)
log_likelihood_values

# Re-print the full summaries with 6 significant digits (global option).
options(digits = 6)
for (robust_model in robust_model_list) {
  print(summary(robust_model))
}

# --- Multilevel Beta regressions as additional robustness ---
# Beta regression needs responses strictly inside (0, 1): squeeze each of the
# three function shares into [0.001, 0.999].
local_grid <- local_grid |>
  mutate(across(c(O, C, S), \(share) pmax(pmin(share, 0.999), 0.001)))

# GC levels follow order of first appearance; hubei becomes categorical.
local_grid$GC <- factor(local_grid$GC, levels = unique(local_grid$GC))
local_grid$hubei <- as.factor(local_grid$hubei)

# Random-intercept (by GC) beta models for each function
# beta_O: share of "Others" (O, clamped into (0, 1) above) regressed on a
# quadratic in day, the regional covariates, hubei (a factor at this point),
# the after_pandemic indicator and logwords, with a random intercept per GC.
# Beta family with logit link for the mean.
beta_O <- glmmTMB(
  O ~ interval +  I(interval^2) + nightlight2020 + loggdppc + logpop + hubei + after_pandemic + logwords + (1 | GC),
  data = local_grid,
  family = beta_family(link = "logit")
)
summary(beta_O)

# Cluster-robust SE function (clustered by GC)
# Sandwich standard errors for the leading fixed-effect coefficients of a
# fitted model, clustering on the `GC` column of `dataset`.
#
# Arguments:
#   model            fitted model supporting model.matrix() and
#                    residuals(type = "pearson"), whose fixed-effect
#                    estimates are the leading entries of
#                    model[["fit"]][["parfull"]] (as for glmmTMB fits).
#   dataset          data frame whose `GC` column gives each row's cluster;
#                    its rows must align one-to-one with the model matrix.
#   num_coefficients number of leading coefficients to report; defaults to
#                    every column of the model matrix.
#
# Returns a data.frame with Coefficients, Std.Errors, Z.values and P.values
# (two-sided, normal reference), one row per coefficient, with row names
# taken from the model-matrix columns.
#
# NOTE(review): bread (X'X)^-1 and a meat built from X scaled by Pearson
# residuals form an OLS-style sandwich. This is not the exact score-based
# sandwich for a beta GLMM, and no small-sample (G / (G - 1)) correction is
# applied, so treat these SEs as an approximation.
clustered_se <- function(model, dataset, num_coefficients = NULL) {
  model_matrix <- model.matrix(model)
  pearson_resid <- residuals(model, type = "pearson")
  cluster_var <- dataset$GC

  # Rows dropped during fitting (e.g. NA values) would silently misalign
  # clusters with residuals, so refuse instead of returning wrong SEs.
  if (length(cluster_var) != nrow(model_matrix)) {
    stop("`dataset$GC` has ", length(cluster_var), " rows but the model matrix has ",
         nrow(model_matrix), "; rows must be aligned.", call. = FALSE)
  }
  if (is.null(num_coefficients)) {
    num_coefficients <- ncol(model_matrix)
  }
  if (num_coefficients > ncol(model_matrix)) {
    stop("`num_coefficients` (", num_coefficients, ") exceeds the number of ",
         "model-matrix columns (", ncol(model_matrix), ").", call. = FALSE)
  }

  # Per-row contributions: each row of X scaled by its residual.
  estfun_matrix <- model_matrix * pearson_resid
  # Sum contributions within each cluster; crossprod() of the cluster sums
  # equals accumulating s_g %*% t(s_g) over clusters g.
  cluster_scores <- rowsum(estfun_matrix, cluster_var)
  meat_matrix <- crossprod(cluster_scores)

  bread_matrix <- solve(crossprod(model_matrix))
  vcov_sandwich <- bread_matrix %*% meat_matrix %*% bread_matrix

  keep <- seq_len(num_coefficients)
  fixed_coefs <- unname(model[["fit"]][["parfull"]][keep])
  std_errors <- unname(sqrt(diag(vcov_sandwich)[keep]))
  z_values <- fixed_coefs / std_errors
  p_values <- 2 * pnorm(-abs(z_values))

  data.frame(
    Coefficients = fixed_coefs,
    Std.Errors = std_errors,
    Z.values = z_values,
    P.values = p_values,
    row.names = colnames(model_matrix)[keep]
  )
}

# Clustered SE for O-model
clustered_se(beta_O, local_grid, 9)

# ICC (intraclass correlation) for O-model:
#   ICC = var(GC intercept) / (var(GC intercept) + mean within-GC variance),
# where each observation's within variance is mu * (1 - mu) / (1 + phi),
# with phi taken from sigma(beta_O) as the beta precision.
# NOTE(review): the random-intercept variance lives on the logit scale while
# mu * (1 - mu) / (1 + phi) is on the proportion scale — confirm this mix of
# scales is intended.
var_random_effects_beta_O <- VarCorr(beta_O)[["cond"]][["GC"]][1]
phi <- sigma(beta_O)
mu_i <- predict(beta_O, type = "response")
within_group_variance <- mu_i * (1 - mu_i) / (1 + phi)
mean_within_group_variance <- mean(within_group_variance)
icc_beta_O <- var_random_effects_beta_O /
  (var_random_effects_beta_O + mean_within_group_variance)
icc_beta_O

# C-model: same specification as beta_O, response = Social Control share.
beta_C <- glmmTMB(
  C ~ interval +  I(interval^2) + nightlight2020 + loggdppc + logpop + hubei + after_pandemic + logwords + (1 | GC),
  data = local_grid,
  family = beta_family(link = "logit")
)
summary(beta_C)
clustered_se(beta_C, local_grid, 9)

# ICC for the C-model: GC intercept variance over itself plus the mean beta
# variance mu * (1 - mu) / (1 + phi).
# NOTE(review): mixes logit-scale and proportion-scale variances — confirm.
var_random_effects_beta_C <- VarCorr(beta_C)[["cond"]][["GC"]][1]
phi <- sigma(beta_C)
mu_i <- predict(beta_C, type = "response")
within_group_variance <- mu_i * (1 - mu_i) / (1 + phi)
mean_within_group_variance <- mean(within_group_variance)
icc_beta_C <- var_random_effects_beta_C /
  (var_random_effects_beta_C + mean_within_group_variance)
icc_beta_C

# S-model: same specification as beta_O, response = Providing Service share.
beta_S <- glmmTMB(
  S ~ interval +  I(interval^2) + nightlight2020 + loggdppc + logpop + hubei + after_pandemic + logwords + (1 | GC),
  data = local_grid,
  family = beta_family(link = "logit")
)
summary(beta_S)
clustered_se(beta_S, local_grid, 9)

# ICC for the S-model: GC intercept variance over itself plus the mean beta
# variance mu * (1 - mu) / (1 + phi).
# NOTE(review): mixes logit-scale and proportion-scale variances — confirm.
var_random_effects_beta_S <- VarCorr(beta_S)[["cond"]][["GC"]][1]
phi <- sigma(beta_S)
mu_i <- predict(beta_S, type = "response")
within_group_variance <- mu_i * (1 - mu_i) / (1 + phi)
mean_within_group_variance <- mean(within_group_variance)
icc_beta_S <- var_random_effects_beta_S /
  (var_random_effects_beta_S + mean_within_group_variance)
icc_beta_S

# --- Robustness plots (Dirichlet RB1–RB3) ---
# RB1 (exclude top 5 prolific GCs)
# Prediction grid: one row per day in the RB1 sample's observed range, the
# other covariates fixed at their RB1 medians, hubei = 0, and
# after_pandemic = 1 for days > 0.
rb1_days <- min(data_Robust1$interval):max(data_Robust1$interval)
new_data3 <- data.frame(
  interval = rb1_days,
  nightlight2020 = median(data_Robust1$nightlight2020, na.rm = TRUE),
  loggdppc = median(data_Robust1$loggdppc, na.rm = TRUE),
  logpop = median(data_Robust1$logpop, na.rm = TRUE),
  hubei = 0,
  logwords = median(data_Robust1$logwords),
  after_pandemic = ifelse(rb1_days > 0, 1, 0)
)
new_data3$interval2 <- new_data3$interval^2

# Point predictions of the expected composition (columns V1-V3).
# NOTE(review): `type = "probs"` does not appear to be an argument of
# DirichletReg's predict() method and is presumably ignored — confirm.
predicted_probs3 <- predict(DRmodel_RB1, newdata = new_data3, type = "probs")
predicted_probs_df3 <- as.data.frame(predicted_probs3)
predicted_probs_df3$interval <- new_data3$interval

# Simulation-based 95% CIs: draw coefficient vectors from a multivariate
# normal centred on the estimates (model vcov), map each draw to proportions
# with a softmax, then take per-day, per-category 2.5% / 97.5% quantiles.
# NOTE(review): the softmax mapping assumes the common parametrization
# (mu_j = exp(eta_j) / sum(exp(eta))) — confirm against the model fit.
coefs3 <- coef(DRmodel_RB1)
vcov_matrix3 <- vcov(DRmodel_RB1)
coefs_vector3 <- c(coefs3$v1, coefs3$v2, coefs3$v3)
n_sims3 <- 2000
set.seed(30605)
simulated_coefs3 <- mvrnorm(n = n_sims3, mu = coefs_vector3, Sigma = vcov_matrix3)

# Row-wise softmax. Subtracting each row's maximum leaves the result
# unchanged but keeps exp() from overflowing.
softmax3 <- function(eta3) {
  exp_eta3 <- exp(eta3 - apply(eta3, 1, max))
  exp_eta3 / rowSums(exp_eta3)
}

# The design matrix depends only on the grid: build it once, outside the loop.
model_matrix3 <- model.matrix(
  ~ 1 + interval + I(interval^2) + nightlight2020 + loggdppc + logpop +
    hubei + logwords + after_pandemic,
  new_data3
)
n_days3 <- length(predicted_probs_df3$interval)
simulated_probs_matrix3 <- matrix(NA, nrow = n_sims3, ncol = n_days3 * 3)

for (i in seq_len(n_sims3)) {
  # One row of coefficients per category, in the v1, v2, v3 order of
  # coefs_vector3.
  sim_coef3 <- matrix(simulated_coefs3[i, ], nrow = 3, byrow = TRUE)
  eta3 <- model_matrix3 %*% t(sim_coef3)
  sim_probs3 <- softmax3(eta3)
  # Day-major layout: column (k - 1) * 3 + j holds day k, category j.
  simulated_probs_matrix3[i, ] <- as.vector(t(sim_probs3))
}

prob_cis3 <- array(dim = c(2, n_days3, 3))
for (j in 1:3) {
  for (k in seq_len(n_days3)) {
    prob_cis3[, k, j] <- quantile(simulated_probs_matrix3[, (k - 1) * 3 + j], probs = c(0.025, 0.975))
  }
}
for (j in 1:3) {
  predicted_probs_df3[[paste0("lower_CI_", j)]] <- prob_cis3[1, , j]
  predicted_probs_df3[[paste0("upper_CI_", j)]] <- prob_cis3[2, , j]
}

# Split at day 0 so the lines and ribbons break at the lockdown.
predicted_probs_df_less_than_zero3 <- predicted_probs_df3[predicted_probs_df3$interval <= 0, ]
predicted_probs_df_greater_than_zero3 <- predicted_probs_df3[predicted_probs_df3$interval > 0, ]

DR_P3 <- ggplot() +
  geom_ribbon(data = predicted_probs_df_less_than_zero3, aes(x = interval, ymin = lower_CI_1, ymax = upper_CI_1), fill = "grey60", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_less_than_zero3, aes(x = interval, ymin = lower_CI_2, ymax = upper_CI_2), fill = "grey40", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_less_than_zero3, aes(x = interval, ymin = lower_CI_3, ymax = upper_CI_3), fill = "grey20", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_greater_than_zero3, aes(x = interval, ymin = lower_CI_1, ymax = upper_CI_1), fill = "grey60", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_greater_than_zero3, aes(x = interval, ymin = lower_CI_2, ymax = upper_CI_2), fill = "grey40", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_greater_than_zero3, aes(x = interval, ymin = lower_CI_3, ymax = upper_CI_3), fill = "grey20", alpha = 0.2) +
  geom_line(data = predicted_probs_df_less_than_zero3, aes(x = interval, y = V1, linetype = "Category 1"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_greater_than_zero3, aes(x = interval, y = V1, linetype = "Category 1"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_less_than_zero3, aes(x = interval, y = V2, linetype = "Category 2"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_greater_than_zero3, aes(x = interval, y = V2, linetype = "Category 2"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_less_than_zero3, aes(x = interval, y = V3, linetype = "Category 3"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_greater_than_zero3, aes(x = interval, y = V3, linetype = "Category 3"), colour = "black", linewidth = 2) +
  geom_vline(xintercept = 0, color = "#0033FF", linewidth = 1, linetype = "dashed") +
  scale_linetype_manual(
    values = c("Category 1" = "solid", "Category 2" = "dotted", "Category 3" = "dashed"),
    labels = c("Others", "Social Control", "Providing Services")
  ) +
  labs(
    title = "",
    x = "Day",
    y = "Proportion",
    linetype = "Functions (with 95% C.I.)"
  ) +
  theme_bw() +
  theme(
    legend.position = c(0.25, 0.85),
    legend.text = element_text(size = 18),
    text = element_text(size = 18),
    plot.title = element_text(size = 18),
    axis.title = element_text(size = 15.5),
    axis.text = element_text(size = 18),
    legend.key.size = unit(1.2, "cm"),
    legend.key.width = unit(2, "cm"),
    legend.spacing.x = unit(1, "cm"),
    legend.spacing.y = unit(0.4, "cm"),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.background = element_blank(),
    panel.border = element_rect(colour = "black", fill = NA, linewidth = 1)
  ) +
  scale_x_continuous(
    breaks = c(-50, -25, 0, 25, 50),
    limits = c(-55, 55)
  ) +
  scale_y_continuous(
    breaks = c(0.2, 0.4, 0.6, 0.8),
    limits = c(0, 0.85)
  ) +
  annotate("text", x = -15, y = 0.1, label = "Wuhan Lockdown", hjust = 1, size = 7, color = "black") +
  annotate("segment", x = -14, xend = -1, y = 0.1, yend = 0.1,
           colour = "black", linewidth = 1, arrow = arrow(length = unit(0.2, "cm"))) +
  # NOTE(review): N is hard-coded; confirm it matches nrow(data_Robust1).
  annotate("text", x = 30, y = 0.8, label = "Dirichlet Model\nExcluding the Five Most Prolific GCs\nN = 1,882",
           size = 6, fontface = "bold", colour = "black")

print(DR_P3)

# RB2 (city-level only)
# Prediction grid: one row per day in the RB2 sample's observed range, the
# other covariates fixed at their RB2 medians, hubei = 0, and
# after_pandemic = 1 for days > 0.
rb2_days <- min(data_Robust2$interval):max(data_Robust2$interval)
new_local_grid <- data.frame(
  interval = rb2_days,
  nightlight2020 = median(data_Robust2$nightlight2020, na.rm = TRUE),
  loggdppc = median(data_Robust2$loggdppc, na.rm = TRUE),
  logpop = median(data_Robust2$logpop, na.rm = TRUE),
  hubei = 0,
  logwords = median(data_Robust2$logwords),
  after_pandemic = ifelse(rb2_days > 0, 1, 0)
)
new_local_grid$interval2 <- new_local_grid$interval^2

# Point predictions of the expected composition (columns V1-V3).
# NOTE(review): `type = "probs"` does not appear to be an argument of
# DirichletReg's predict() method and is presumably ignored — confirm.
predicted_probs4 <- predict(DRmodel_RB2, newdata = new_local_grid, type = "probs")
predicted_probs_df4 <- as.data.frame(predicted_probs4)
predicted_probs_df4$interval <- new_local_grid$interval

# Simulation-based 95% CIs: draw coefficient vectors from a multivariate
# normal centred on the estimates (model vcov), map each draw to proportions
# with a softmax, then take per-day, per-category 2.5% / 97.5% quantiles.
# NOTE(review): the softmax mapping assumes the common parametrization —
# confirm against the model fit.
coefs4 <- coef(DRmodel_RB2)
vcov_matrix4 <- vcov(DRmodel_RB2)
coefs_vector4 <- c(coefs4$v1, coefs4$v2, coefs4$v3)
n_sims4 <- 2000
set.seed(30605)
simulated_coefs4 <- mvrnorm(n = n_sims4, mu = coefs_vector4, Sigma = vcov_matrix4)

# Row-wise softmax. Subtracting each row's maximum leaves the result
# unchanged but keeps exp() from overflowing.
softmax4 <- function(eta4) {
  exp_eta4 <- exp(eta4 - apply(eta4, 1, max))
  exp_eta4 / rowSums(exp_eta4)
}

# The design matrix depends only on the grid: build it once, outside the loop.
model_matrix4 <- model.matrix(
  ~ 1 + interval + I(interval^2) + nightlight2020 + loggdppc + logpop +
    hubei + logwords + after_pandemic,
  new_local_grid
)
n_days4 <- length(predicted_probs_df4$interval)
simulated_probs_matrix4 <- matrix(NA, nrow = n_sims4, ncol = n_days4 * 3)

for (i in seq_len(n_sims4)) {
  # One row of coefficients per category, in the v1, v2, v3 order of
  # coefs_vector4.
  sim_coef4 <- matrix(simulated_coefs4[i, ], nrow = 3, byrow = TRUE)
  eta4 <- model_matrix4 %*% t(sim_coef4)
  sim_probs4 <- softmax4(eta4)
  # Day-major layout: column (k - 1) * 3 + j holds day k, category j.
  simulated_probs_matrix4[i, ] <- as.vector(t(sim_probs4))
}

prob_cis4 <- array(dim = c(2, n_days4, 3))
for (j in 1:3) {
  for (k in seq_len(n_days4)) {
    prob_cis4[, k, j] <- quantile(simulated_probs_matrix4[, (k - 1) * 3 + j], probs = c(0.025, 0.975))
  }
}

for (j in 1:3) {
  predicted_probs_df4[[paste0("lower_CI_", j)]] <- prob_cis4[1, , j]
  predicted_probs_df4[[paste0("upper_CI_", j)]] <- prob_cis4[2, , j]
}

# Split at day 0 so the lines and ribbons break at the lockdown.
predicted_probs_df_less_than_zero4 <- predicted_probs_df4[predicted_probs_df4$interval <= 0, ]
predicted_probs_df_greater_than_zero4 <- predicted_probs_df4[predicted_probs_df4$interval > 0, ]

DR_P4 <- ggplot() +
  geom_ribbon(data = predicted_probs_df_less_than_zero4, aes(x = interval, ymin = lower_CI_1, ymax = upper_CI_1), fill = "grey60", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_less_than_zero4, aes(x = interval, ymin = lower_CI_2, ymax = upper_CI_2), fill = "grey40", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_less_than_zero4, aes(x = interval, ymin = lower_CI_3, ymax = upper_CI_3), fill = "grey20", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_greater_than_zero4, aes(x = interval, ymin = lower_CI_1, ymax = upper_CI_1), fill = "grey60", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_greater_than_zero4, aes(x = interval, ymin = lower_CI_2, ymax = upper_CI_2), fill = "grey40", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_greater_than_zero4, aes(x = interval, ymin = lower_CI_3, ymax = upper_CI_3), fill = "grey20", alpha = 0.2) +
  geom_line(data = predicted_probs_df_less_than_zero4, aes(x = interval, y = V1, linetype = "Category 1"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_greater_than_zero4, aes(x = interval, y = V1, linetype = "Category 1"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_less_than_zero4, aes(x = interval, y = V2, linetype = "Category 2"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_greater_than_zero4, aes(x = interval, y = V2, linetype = "Category 2"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_less_than_zero4, aes(x = interval, y = V3, linetype = "Category 3"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_greater_than_zero4, aes(x = interval, y = V3, linetype = "Category 3"), colour = "black", linewidth = 2) +
  geom_vline(xintercept = 0, color = "#0033FF", linewidth = 1, linetype = "dashed") +
  scale_linetype_manual(
    values = c("Category 1" = "solid", "Category 2" = "dotted", "Category 3" = "dashed"),
    labels = c("Others", "Social Control", "Providing Services")
  ) +
  labs(
    title = "",
    x = "Day",
    y = "Proportion",
    linetype = "Functions (with 95% C.I.)"
  ) +
  theme_bw() +
  theme(
    legend.position = c(0.25, 0.85),
    legend.text = element_text(size = 18),
    text = element_text(size = 18),
    plot.title = element_text(size = 18),
    axis.title = element_text(size = 15.5),
    axis.text = element_text(size = 18),
    legend.key.size = unit(1.2, "cm"),
    legend.key.width = unit(2, "cm"),
    legend.spacing.x = unit(1, "cm"),
    legend.spacing.y = unit(0.4, "cm"),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.background = element_blank(),
    panel.border = element_rect(colour = "black", fill = NA, linewidth = 1)
  ) +
  scale_x_continuous(
    breaks = c(-50, -25, 0, 25, 50),
    limits = c(-55, 55)
  ) +
  scale_y_continuous(
    breaks = c(0.2, 0.4, 0.6, 0.8),
    limits = c(0, 0.85)
  ) +
  annotate("text", x = -15, y = 0.1, label = "Wuhan Lockdown", hjust = 1, size = 7, color = "black") +
  annotate("segment", x = -14, xend = -1, y = 0.1, yend = 0.1,
           colour = "black", linewidth = 1, arrow = arrow(length = unit(0.2, "cm"))) +
  # NOTE(review): N is hard-coded; confirm it matches nrow(data_Robust2).
  annotate("text", x = 30, y = 0.8, label = "Dirichlet Model\n Six Municipal-level GCs\n N = 751", size = 7, fontface = "bold", colour = "black")

print(DR_P4)

# RB3 (county and below)
# Prediction grid: one row per day in the RB3 sample's observed range, the
# other covariates fixed at their RB3 medians, hubei = 0, and
# after_pandemic = 1 for days > 0.
rb3_days <- min(data_Robust3$interval):max(data_Robust3$interval)
new_data5 <- data.frame(
  interval = rb3_days,
  nightlight2020 = median(data_Robust3$nightlight2020, na.rm = TRUE),
  loggdppc = median(data_Robust3$loggdppc, na.rm = TRUE),
  logpop = median(data_Robust3$logpop, na.rm = TRUE),
  hubei = 0,
  logwords = median(data_Robust3$logwords),
  after_pandemic = ifelse(rb3_days > 0, 1, 0)
)
new_data5$interval2 <- new_data5$interval^2

# Point predictions of the expected composition (columns V1-V3).
# NOTE(review): `type = "probs"` does not appear to be an argument of
# DirichletReg's predict() method and is presumably ignored — confirm.
predicted_probs5 <- predict(DRmodel_RB3, newdata = new_data5, type = "probs")
predicted_probs_df5 <- as.data.frame(predicted_probs5)
predicted_probs_df5$interval <- new_data5$interval

# Simulation-based 95% CIs: draw coefficient vectors from a multivariate
# normal centred on the estimates (model vcov), map each draw to proportions
# with a softmax, then take per-day, per-category 2.5% / 97.5% quantiles.
# NOTE(review): the softmax mapping assumes the common parametrization —
# confirm against the model fit.
coefs5 <- coef(DRmodel_RB3)
vcov_matrix5 <- vcov(DRmodel_RB3)
coefs_vector5 <- c(coefs5$v1, coefs5$v2, coefs5$v3)
n_sims5 <- 2000
set.seed(30605)
simulated_coefs5 <- mvrnorm(n = n_sims5, mu = coefs_vector5, Sigma = vcov_matrix5)

# Row-wise softmax. Subtracting each row's maximum leaves the result
# unchanged but keeps exp() from overflowing.
softmax5 <- function(eta5) {
  exp_eta5 <- exp(eta5 - apply(eta5, 1, max))
  exp_eta5 / rowSums(exp_eta5)
}

# The design matrix depends only on the grid: build it once, outside the loop.
model_matrix5 <- model.matrix(
  ~ 1 + interval + I(interval^2) + nightlight2020 + loggdppc + logpop +
    hubei + logwords + after_pandemic,
  new_data5
)
n_days5 <- length(predicted_probs_df5$interval)
simulated_probs_matrix5 <- matrix(NA, nrow = n_sims5, ncol = n_days5 * 3)

for (i in seq_len(n_sims5)) {
  # One row of coefficients per category, in the v1, v2, v3 order of
  # coefs_vector5.
  sim_coef5 <- matrix(simulated_coefs5[i, ], nrow = 3, byrow = TRUE)
  eta5 <- model_matrix5 %*% t(sim_coef5)
  sim_probs5 <- softmax5(eta5)
  # Day-major layout: column (k - 1) * 3 + j holds day k, category j.
  simulated_probs_matrix5[i, ] <- as.vector(t(sim_probs5))
}

prob_cis5 <- array(dim = c(2, n_days5, 3))
for (j in 1:3) {
  for (k in seq_len(n_days5)) {
    prob_cis5[, k, j] <- quantile(simulated_probs_matrix5[, (k - 1) * 3 + j], probs = c(0.025, 0.975))
  }
}

for (j in 1:3) {
  predicted_probs_df5[[paste0("lower_CI_", j)]] <- prob_cis5[1, , j]
  predicted_probs_df5[[paste0("upper_CI_", j)]] <- prob_cis5[2, , j]
}

# Split at day 0 so the lines and ribbons break at the lockdown.
predicted_probs_df_less_than_zero5 <- predicted_probs_df5[predicted_probs_df5$interval <= 0, ]
predicted_probs_df_greater_than_zero5 <- predicted_probs_df5[predicted_probs_df5$interval > 0, ]

DR_P5 <- ggplot() +
  geom_ribbon(data = predicted_probs_df_less_than_zero5, aes(x = interval, ymin = lower_CI_1, ymax = upper_CI_1), fill = "grey60", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_less_than_zero5, aes(x = interval, ymin = lower_CI_2, ymax = upper_CI_2), fill = "grey40", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_less_than_zero5, aes(x = interval, ymin = lower_CI_3, ymax = upper_CI_3), fill = "grey20", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_greater_than_zero5, aes(x = interval, ymin = lower_CI_1, ymax = upper_CI_1), fill = "grey60", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_greater_than_zero5, aes(x = interval, ymin = lower_CI_2, ymax = upper_CI_2), fill = "grey40", alpha = 0.2) +
  geom_ribbon(data = predicted_probs_df_greater_than_zero5, aes(x = interval, ymin = lower_CI_3, ymax = upper_CI_3), fill = "grey20", alpha = 0.2) +
  geom_line(data = predicted_probs_df_less_than_zero5, aes(x = interval, y = V1, linetype = "Category 1"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_greater_than_zero5, aes(x = interval, y = V1, linetype = "Category 1"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_less_than_zero5, aes(x = interval, y = V2, linetype = "Category 2"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_greater_than_zero5, aes(x = interval, y = V2, linetype = "Category 2"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_less_than_zero5, aes(x = interval, y = V3, linetype = "Category 3"), colour = "black", linewidth = 2) +
  geom_line(data = predicted_probs_df_greater_than_zero5, aes(x = interval, y = V3, linetype = "Category 3"), colour = "black", linewidth = 2) +
  geom_vline(xintercept = 0, color = "#0033FF", linewidth = 1, linetype = "dashed") +
  scale_linetype_manual(
    values = c("Category 1" = "solid", "Category 2" = "dotted", "Category 3" = "dashed"),
    labels = c("Others", "Social Control", "Providing Services")
  ) +
  labs(
    title = "",
    x = "Day",
    y = "Proportion",
    linetype = "Functions (with 95% C.I.)"
  ) +
  theme_bw() +
  theme(
    legend.position = c(0.25, 0.85),
    legend.text = element_text(size = 18),
    text = element_text(size = 18),
    plot.title = element_text(size = 18),
    axis.title = element_text(size = 15.5),
    axis.text = element_text(size = 18),
    legend.key.size = unit(1.2, "cm"),
    legend.key.width = unit(2, "cm"),
    legend.spacing.x = unit(1, "cm"),
    legend.spacing.y = unit(0.4, "cm"),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.background = element_blank(),
    panel.border = element_rect(colour = "black", fill = NA, linewidth = 1)
  ) +
  scale_x_continuous(
    breaks = c(-50, -25, 0, 25, 50),
    limits = c(-55, 55)
  ) +
  scale_y_continuous(
    breaks = c(0.2, 0.4, 0.6, 0.8),
    limits = c(0, 0.85)
  ) +
  annotate("text", x = -15, y = 0.1, label = "Wuhan Lockdown", hjust = 1, size = 7, color = "black") +
  annotate("segment", x = -14, xend = -1, y = 0.1, yend = 0.1,
           colour = "black", linewidth = 1, arrow = arrow(length = unit(0.2, "cm"))) +
  # NOTE(review): N is hard-coded; confirm it matches nrow(data_Robust3).
  annotate("text", x = 30, y = 0.8, label = "Dirichlet Model\n 48 County or Lower-level GCs\n N = 2,083", size = 7, fontface = "bold", colour = "black")

print(DR_P5)

# --- Multilevel beta regression visualization ---
# Prediction grid over every day in local_grid's observed range.
# NOTE(review): nightlight2020 / loggdppc / logpop use medians from
# `account_data` (defined elsewhere), not local_grid — confirm intended.
# NOTE(review): predictions are made for GC "网格台州"; if that is an existing
# level, they include that GC's estimated random intercept — confirm intended.
ml_days <- min(local_grid$interval):max(local_grid$interval)
ml_data1 <- data.frame(
  interval = ml_days,
  nightlight2020 = median(account_data$nightlight2020, na.rm = TRUE),
  loggdppc = median(account_data$loggdppc, na.rm = TRUE),
  logpop = median(account_data$logpop, na.rm = TRUE),
  # hubei entered the beta models as a factor, so supply the reference value
  # as a factor with the model's levels rather than a bare numeric 0.
  hubei = factor("0", levels = levels(local_grid$hubei)),
  logwords = median(local_grid$logwords),
  after_pandemic = ifelse(ml_days > 0, 1, 0),
  GC = rep("网格台州", length(ml_days))
)
ml_data1$interval2 <- ml_data1$interval^2

# Adds predicted_<suffix>, predicted_<suffix>_se, ci_lower_<suffix> and
# ci_upper_<suffix> columns for one beta model (logit link) to `grid`.
# The 95% Wald interval is built on the logit scale and back-transformed
# with plogis(), so it always lies inside (0, 1). A symmetric
# response-scale interval (mu ± 1.96 se) can fall below 0 or exceed 1 near
# the boundaries. The *_se column is the response-scale delta-method SE,
# se_link * mu * (1 - mu). Only one predict() call is made per model.
add_beta_prediction <- function(grid, model, suffix) {
  pred_link <- predict(model, newdata = grid, type = "link", se.fit = TRUE, allow.new.levels = TRUE)
  eta <- as.numeric(pred_link$fit)
  se_eta <- as.numeric(pred_link$se.fit)
  mu <- plogis(eta)
  grid[[paste0("predicted_", suffix)]] <- mu
  grid[[paste0("predicted_", suffix, "_se")]] <- se_eta * mu * (1 - mu)
  grid[[paste0("ci_lower_", suffix)]] <- plogis(eta - 1.96 * se_eta)
  grid[[paste0("ci_upper_", suffix)]] <- plogis(eta + 1.96 * se_eta)
  grid
}

# Predicted means and 95% CI for O/C/S under beta models
ml_data1 <- add_beta_prediction(ml_data1, beta_O, "O")
ml_data1 <- add_beta_prediction(ml_data1, beta_C, "C")
ml_data1 <- add_beta_prediction(ml_data1, beta_S, "S")

# Split at day 0 so plotted lines and ribbons break at the lockdown.
ml_data1_before_COVID <- ml_data1[ml_data1$interval <= 0, ]
ml_data1_after_COVID <- ml_data1[ml_data1$interval > 0, ]

# Panel for beta-model predictions (Others/Control/Service).
# For each function (O = Others, C = Social Control, S = Providing Services)
# a CI ribbon and a fitted line are drawn. The pre- and post-lockdown subsets
# are separate layers, so each series breaks at interval = 0 instead of being
# joined across the discontinuity.
mlrb_linetype_keys <- c(O = "Category 1", C = "Category 2", S = "Category 3")
mlrb_ribbon_fills <- c(O = "grey80", C = "grey50", S = "grey20")
mlrb_periods <- list(ml_data1_before_COVID, ml_data1_after_COVID)

# Ribbon layers: O/C/S for the pre-period, then O/C/S for the post-period.
# `!!` inlines the column name at construction time, so each layer keeps its
# own column.
mlrb_ribbons <- do.call(c, lapply(mlrb_periods, function(period_df) {
  lapply(names(mlrb_linetype_keys), function(fn) {
    geom_ribbon(
      data = period_df,
      aes(
        x = interval,
        ymin = .data[[!!paste0("ci_lower_", fn)]],
        ymax = .data[[!!paste0("ci_upper_", fn)]]
      ),
      fill = mlrb_ribbon_fills[[fn]],
      alpha = 0.2
    )
  })
}))

# Fitted-line layers: for each function, pre-period then post-period. The
# linetype aesthetic is a constant key, mapped to a style by the manual scale.
mlrb_lines <- do.call(c, lapply(names(mlrb_linetype_keys), function(fn) {
  lapply(mlrb_periods, function(period_df) {
    geom_line(
      data = period_df,
      aes(
        x = interval,
        y = .data[[!!paste0("predicted_", fn)]],
        linetype = !!mlrb_linetype_keys[[fn]]
      ),
      colour = "black",
      linewidth = 2
    )
  })
}))

MLRB <- ggplot() +
  mlrb_ribbons +
  mlrb_lines +
  geom_vline(xintercept = 0, color = "#0033FF", linewidth = 1, linetype = "dashed") +
  scale_linetype_manual(
    values = c("Category 1" = "solid", "Category 2" = "dotted", "Category 3" = "dashed"),
    # Explicit breaks pin each label to its key instead of relying on order.
    breaks = c("Category 1", "Category 2", "Category 3"),
    labels = c("Others", "Social Control", "Providing Services")
  ) +
  labs(
    title = "",
    x = "Day",
    y = "Proportion",
    linetype = "Functions (with 95% C.I.)"
  ) +
  theme_bw() +
  theme(
    # NOTE(review): numeric legend.position is deprecated in ggplot2 >= 3.5
    # (use legend.position = "inside" + legend.position.inside).
    legend.position = c(0.25, 0.85),
    legend.text = element_text(size = 18),
    text = element_text(size = 18),
    plot.title = element_text(size = 18),
    axis.title = element_text(size = 15.5),
    axis.text = element_text(size = 18),
    legend.key.size = unit(1.2, "cm"),
    legend.key.width = unit(2, "cm"),
    legend.spacing.x = unit(1, "cm"),
    legend.spacing.y = unit(0.4, "cm"),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.background = element_blank(),
    panel.border = element_rect(colour = "black", fill = NA, linewidth = 1)
  ) +
  scale_x_continuous(
    breaks = c(-50, -25, 0, 25, 50),
    limits = c(-55, 55)
  ) +
  scale_y_continuous(
    breaks = c(0.2, 0.4, 0.6, 0.8),
    limits = c(0, 0.85)
  ) +
  annotate("text", x = -15, y = 0.1, label = "Wuhan Lockdown", hjust = 1, size = 7, color = "black") +
  annotate("segment", x = -14, xend = -1, y = 0.1, yend = 0.1,
           colour = "black", linewidth = 1, arrow = arrow(length = unit(0.2, "cm"))) +
  annotate("text", x = 30, y = 0.8, label = "Multilevel Beta Regression\n for Each Function \n N = 2,834 for each model", size = 7, fontface = "bold", colour = "black")

print(MLRB)

# --- Figure A8: Robustness panels (Dirichlet RB1–RB3 + Multilevel Beta) ---
# Bold title grob for the figure.
# NOTE(review): title_grobA5 is built here but is not passed to grid.arrange()
# below, so it does not appear in FIGUREA8.png. The "A5" suffix also does not
# match the figure number. Confirm the intent.
title_textA5 <- "FIGURE A8-1. Robustness Check of the Dirichlet Model"
title_grobA5 <- grobTree(
  textGrob(
    label = title_textA5,
    y = -0.6,
    gp = gpar(fontsize = 24, fontface = "bold")
  )
)

# Lay out the three Dirichlet robustness panels and the multilevel beta panel
# in a two-column grid. grid.arrange() draws to the active device and returns
# the arranged gtable, which is then written to disk.
figureA8 <- grid.arrange(grobs = list(DR_P3, DR_P4, DR_P5, MLRB), ncol = 2)

# NOTE(review): 20.5 x 16.4 in at 1200 dpi produces a 24600 x 19680 px PNG.
# Confirm that this resolution is intended.
ggsave(
  filename = "plots/FIGUREA8.png",
  plot = figureA8,
  width = 20.5,
  height = 16.4,
  units = "in",
  dpi = 1200
)
